import torch
from torch.utils.data import Dataset, DataLoader
import glob
from sklearn.model_selection import train_test_split
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error, confusion_matrix
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score
import os
# Seed Python's and NumPy's global RNGs at import time so random draws are repeatable.
# NOTE(review): torch's RNG is not seeded in this section - confirm whether it is seeded elsewhere.
random.seed(2024)
np.random.seed(2024)
class CustomDataset(Dataset):
    """Index-aligned pairing of input sequences with their labels.

    Item ``i`` is the tuple ``(sequences1[i], labels[i])``; the dataset length
    is taken from ``labels``.
    """

    def __init__(self, sequences1, labels):
        # Both containers are stored as given (no copy, no conversion).
        self.sequences1, self.labels = sequences1, labels

    def __len__(self):
        """Return the number of labels."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return the ``(sequence, label)`` pair stored at *index*."""
        return self.sequences1[index], self.labels[index]

def give_batch(path, test_size=0.3, random_state=42):
    """Load a tab-separated feature/label file and split it into train/test sets.

    Every non-blank line must have the form ``f1,f2,...,fn<TAB>label``; any
    columns after the second are ignored. Empty feature fields (for example
    the middle one in ``1,,3``) are read as ``0.0``. Blank lines are skipped.

    Args:
        path: Path of the UTF-8 text file to read.
        test_size: Forwarded to ``train_test_split``; defaults to 0.3.
        random_state: Seed forwarded to ``train_test_split`` so the shuffled
            split is reproducible; defaults to 42.

    Returns:
        The tuple ``(X_train, X_test, y_train, y_test)`` produced by
        ``train_test_split``.

    Raises:
        IndexError: If a non-blank line has no tab-separated label column.
        ValueError: If a feature or label cannot be parsed as a float.
    """
    features = []
    targets = []
    with open(path, 'r', encoding="utf-8") as f:
        for raw_line in f:
            record = raw_line.strip()
            if not record:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            columns = record.split('\t')
            target = float(columns[1])
            # Empty fields must be replaced *before* float(), which rejects ''.
            row = [float(field) if field.strip() else 0.0
                   for field in columns[0].strip().split(',')]
            features.append(row)
            targets.append(target)
    X_train, X_test, y_train, y_test = train_test_split(
        features, targets, test_size=test_size, random_state=random_state, shuffle=True)
    return X_train, X_test, y_train, y_test

def _min_max_normalize(values):
    """Rescale *values* linearly to [0, 1]; a constant input maps to all zeros."""
    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        # Avoid 0/0 -> NaN, which would otherwise poison every later sum.
        return np.zeros_like(values, dtype=float)
    return (values - lowest) / span


def heatmapVisual(model,dataloader,data_length,device,log_dir):
    """Accumulate per-sample attention maps over a dataloader and plot them.

    For each batch the embeddings ``E = model.embedding1(x)`` are multiplied
    with their own transpose, passed through ``model.activate``, min-max
    normalized per sample and summed into one ``data_length x data_length``
    matrix. Once all batches are processed the function writes, into
    ``log_dir``:

    * ``accumulated_matrix.txt`` - the raw accumulated matrix,
    * ``accumulated_attention_matrix_heatmap.png`` - its heatmap,
    * ``column_summed_normalized_heatmap.png`` - the normalized column sums.

    Nothing is written when the dataloader yields no batches.

    Args:
        model: Object exposing ``embedding1`` and ``activate`` callables.
        dataloader: Iterable of ``(x, y)`` batches; only ``x`` is used.
        data_length: Side length of the accumulated square matrix.
            NOTE(review): presumably equals the sequence length of ``x`` so
            each per-sample matrix matches the accumulator - confirm.
        device: Device the inputs are moved to before the forward pass.
        log_dir: Existing directory that receives the output files.
    """
    accumulated_matrix = np.zeros((data_length, data_length))
    batches_seen = 0
    # Only forward values are needed, so skip building the autograd graph.
    with torch.no_grad():
        for x, _ in dataloader:
            x = x.long().to(device)
            E = model.embedding1(x)
            A = torch.matmul(E, E.transpose(-2, -1))
            attention_matrix = model.activate(A).detach().cpu().numpy()
            for current_matrix in attention_matrix:
                accumulated_matrix += _min_max_normalize(current_matrix)
            batches_seen += 1

    if not batches_seen:
        return

    # Outputs are produced once, after accumulation, instead of being
    # regenerated on every batch.
    np.savetxt(os.path.join(log_dir, 'accumulated_matrix.txt'), accumulated_matrix)

    plt.figure(figsize=(10, 8))
    sns.heatmap(accumulated_matrix, cmap='viridis')
    plt.title("Accumulated Normalized Attention Matrix Heatmap")
    # NOTE(review): both axes index the same data_length dimension; the axis
    # labels are kept unchanged so existing figures stay comparable.
    plt.xlabel("Embedding Dimension")
    plt.ylabel("Sequence Length")
    plt.savefig(os.path.join(log_dir,'accumulated_attention_matrix_heatmap.png'), bbox_inches='tight')
    plt.close()

    column_sum1 = np.sum(accumulated_matrix, axis=0)
    normalized_column_sum1 = _min_max_normalize(column_sum1)
    # seaborn's heatmap needs 2-D input, so present the vector as a single row.
    heatmap_matrix1 = np.expand_dims(normalized_column_sum1, axis=0)
    plt.figure(figsize=(12, 2))
    sns.heatmap(heatmap_matrix1, cmap='viridis', cbar=True, annot=True)
    plt.title("Column-wise Summed and Normalized Feature Heatmap")
    plt.xlabel("Embedding Dimension")
    plt.yticks([])
    plt.savefig(os.path.join(log_dir,'column_summed_normalized_heatmap.png'), bbox_inches='tight')
    plt.close()
        
        
def plot_r2_trends(R2_list,log_dir,prefix='train'):
    """Plot the R2 score per epoch and save it as a PNG in ``log_dir``.

    The y-axis is zoomed to the 10th-90th percentile band of the scores,
    widened by half the band on each side and clamped to ``[0, 1]``. When
    that window is empty or inverted (for example when every score is
    negative, or the list is empty) matplotlib's automatic limits are kept
    so the curve stays visible.

    Args:
        R2_list: Sequence of R2 scores, one per epoch (epochs start at 1).
        log_dir: Existing directory that receives the image.
        prefix: Label used in the title and in the file name
            ``r2_{prefix}_score_vs_epoch.png``.
    """
    plt.figure(figsize=(30, 6))
    epochs = range(1, len(R2_list) + 1)
    plt.plot(epochs, R2_list, marker='o', linestyle='-', color='b')
    plt.xlabel('Epoch')
    plt.ylabel('R2 Score')
    plt.title(f"R2 {prefix} Score vs Epoch")
    plt.grid(True)

    # Percentiles are undefined for an empty sequence, so skip the zoom then.
    if len(R2_list) > 0:
        lower_percentile, upper_percentile = np.percentile(R2_list, [10, 90])
        margin = (upper_percentile - lower_percentile) * 0.5
        y_min = max(0, lower_percentile - margin)
        y_max = min(1, upper_percentile + margin)
        # Clamping can produce y_min >= y_max; applying that would invert or
        # collapse the axis, so only set limits for a proper window.
        if y_min < y_max:
            plt.ylim(y_min, y_max)

    plt.savefig(os.path.join(log_dir,f"r2_{prefix}_score_vs_epoch.png"))
    plt.close()



def plotActualVsPredicted(pred, label, log_dir, prefix):
    """Scatter predictions against true values and save the figure.

    The image is written to ``{log_dir}/{prefix}_true_vs_predicted.png`` and
    the current matplotlib figure is closed afterwards.
    """
    # True values go on the x-axis, predictions on the y-axis.
    plt.scatter(x=label, y=pred)
    plt.title('True vs. Predicted Values')
    plt.ylabel('Predictions')
    plt.xlabel('True Values')
    figure_path = os.path.join(log_dir, f"{prefix}_true_vs_predicted.png")
    plt.savefig(figure_path, bbox_inches='tight')
    plt.close()
    
def getRegressionMetrics (model,dataloader,device,log_dir,prefix='train'):
    """Evaluate *model* on *dataloader* and return regression metrics.

    The model is switched to eval mode, predictions are collected for every
    batch, and a true-vs-predicted scatter plot of the raw predictions is
    saved via ``plotActualVsPredicted``. Negative predictions are then
    replaced by their absolute value. MSE, RMSE and MAE are computed on the
    natural logarithm of labels and predictions; R2 is computed on the
    (absolute-valued) predictions in the original scale.

    NOTE(review): ``np.log`` yields ``-inf``/``NaN`` for zero or negative
    inputs, so labels are presumably expected to be strictly positive -
    confirm against the data.

    Args:
        model: Callable model; invoked as ``model(x)`` with integer inputs.
        dataloader: Iterable of ``(x, y)`` batches.
        device: Device the inputs are moved to before the forward pass.
        log_dir: Directory passed on to ``plotActualVsPredicted``.
        prefix: File-name prefix passed on to ``plotActualVsPredicted``.

    Returns:
        The tuple ``(MSE, RMSE, MAE, R2)``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for x_test, y_test in dataloader:
            x_test = x_test.long().to(device)
            predictions = model(x_test)
            all_preds.extend(predictions.detach().cpu().numpy())
            # unsqueeze(1) gives labels a trailing axis of size 1, as before.
            all_labels.extend(y_test.float().unsqueeze(1).cpu().numpy())

    # The plot shows the raw predictions, before the sign fix below.
    plotActualVsPredicted(all_preds,all_labels,log_dir,prefix)

    labels = np.asarray(all_labels)
    # Fold negative predictions onto the positive axis so the log is defined.
    preds = np.abs(np.asarray(all_preds))

    log_labels = np.log(labels)
    log_preds = np.log(preds)
    MSE = mean_squared_error(log_labels, log_preds)
    RMSE = np.sqrt(MSE)
    MAE = mean_absolute_error(log_labels, log_preds)
    R2 = r2_score(labels, preds)
    return MSE,RMSE,MAE,R2
def mean_and_se(values, ddof=0):
    """Return the mean of *values* and the standard error of that mean.

    The standard error is ``std(values, ddof) / sqrt(len(values))``.

    Args:
        values: Non-empty sized sequence of numbers.
        ddof: Delta degrees of freedom forwarded to ``np.std``. The default
            ``0`` (population standard deviation) keeps earlier results
            unchanged; pass ``1`` for the sample-based standard error.

    Returns:
        The tuple ``(mean, standard_error)``.
    """
    mean = np.mean(values)
    se = np.std(values, ddof=ddof) / np.sqrt(len(values))
    return mean, se
